import gym
import numpy as np
from typing import Any, Dict, List, Optional, Tuple, Type, Union
# import torch as th

# from stable_baselines3.common.base_class import BaseAlgorithm
# from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer
# from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.noise import ActionNoise, VectorizedActionNoise
# from stable_baselines3.common.policies import BasePolicy
# from stable_baselines3.common.save_util import load_from_pkl, save_to_pkl
# from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutReturn, Schedule, TrainFreq, TrainFrequencyUnit
# from stable_baselines3.common.utils import safe_mean, should_collect_more_steps
# from stable_baselines3.common.vec_env import VecEnv
# from stable_baselines3.her.her_replay_buffer import HerReplayBuffer


from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3 import SAC, DDPG


class CustomSAC(SAC):
    """SAC variant whose warm-up actions are produced by the environments.

    While ``num_timesteps < learning_starts`` (and ``use_sde_at_warmup`` is
    not set), actions are obtained by calling ``sample_action()`` on every
    sub-environment through ``VecEnv.env_method``, instead of sampling from
    ``action_space``. Every wrapped environment must therefore expose a
    ``sample_action`` method. After warm-up the stochastic SAC policy is used.
    """

    def __init__(self, *args, **kwargs):
        # Forward positional arguments as well, so ``CustomSAC(policy, env, ...)``
        # works like ``SAC(policy, env, ...)``; keyword-only calls are unchanged.
        super().__init__(*args, **kwargs)

    def _sample_action(
        self,
        learning_starts: int,
        action_noise: Optional[ActionNoise] = None,
        n_envs: int = 1,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Sample an action according to the exploration policy.

        During warm-up each environment's own ``sample_action()`` supplies the
        action and no action noise is applied. Afterwards the policy is
        queried non-deterministically, and ``action_noise`` (if given) is added
        in the normalized [-1, 1] space before clipping.

        :param learning_starts: Number of steps before learning for the warm-up phase.
        :param action_noise: Action noise used for exploration after warm-up.
            Optional for SAC, which already has a stochastic policy.
            Ignored during warm-up.
        :param n_envs: Number of parallel environments. Unused here: the
            number of warm-up actions is whatever ``env_method`` returns (one
            per sub-environment). Kept for signature compatibility.
        :return: action to take in the environment
            and scaled action that will be stored in the replay buffer.
            The two differ when the action space is not normalized (bounds are not [-1, 1]).
        """
        noise = action_noise
        # NOTE(review): upstream SB3 skips random warm-up only when
        # ``use_sde and use_sde_at_warmup``; here ``use_sde_at_warmup`` alone
        # decides, so setting it without gSDE makes the policy act from the
        # first step. Confirm this is intended.
        if self.num_timesteps < learning_starts and not self.use_sde_at_warmup:
            # Warm-up: let each environment choose its action, unperturbed.
            noise = None
            # env_method returns a list (one entry per sub-env); stack it so
            # both branches yield an ndarray as the signature promises.
            unscaled_action = np.array(self.env.env_method("sample_action"))
        else:
            # Continuous actions assume the policy squashes with tanh.
            # Non-deterministic sampling matters for SAC's stochastic policy.
            unscaled_action, _ = self.predict(self._last_obs, deterministic=False)

        if not isinstance(self.action_space, gym.spaces.Box):
            # Discrete case: no normalization or clipping needed.
            return unscaled_action, unscaled_action

        # Rescale from [low, high] to [-1, 1]; the buffer stores this form.
        scaled_action = self.policy.scale_action(unscaled_action)
        if noise is not None:
            scaled_action = np.clip(scaled_action + noise(), -1, 1)

        buffer_action = scaled_action
        action = self.policy.unscale_action(scaled_action)
        return action, buffer_action


class CustomDDPG(DDPG):
    """DDPG variant whose warm-up actions are produced by the environments.

    While ``num_timesteps < learning_starts``, actions are obtained by calling
    ``sample_action()`` on every sub-environment through
    ``VecEnv.env_method``, instead of sampling from ``action_space``. Every
    wrapped environment must therefore expose a ``sample_action`` method.
    After warm-up the policy is used, with optional action noise.
    """

    def __init__(self, *args, **kwargs):
        # Forward positional arguments as well, so ``CustomDDPG(policy, env, ...)``
        # works like ``DDPG(policy, env, ...)``; keyword-only calls are unchanged.
        super().__init__(*args, **kwargs)

    def _sample_action(
        self,
        learning_starts: int,
        action_noise: Optional[ActionNoise] = None,
        n_envs: int = 1,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Sample an action according to the exploration policy.

        During warm-up each environment's own ``sample_action()`` supplies the
        action and no action noise is applied. Afterwards the policy is
        queried, and ``action_noise`` (if given) is added in the normalized
        [-1, 1] space before clipping.

        :param learning_starts: Number of steps before learning for the warm-up phase.
        :param action_noise: Action noise used for exploration after warm-up.
            For a deterministic policy such as DDPG's, this is the source of
            exploration. Ignored during warm-up.
        :param n_envs: Number of parallel environments. Unused here: the
            number of warm-up actions is whatever ``env_method`` returns (one
            per sub-environment). Kept for signature compatibility.
        :return: action to take in the environment
            and scaled action that will be stored in the replay buffer.
            The two differ when the action space is not normalized (bounds are not [-1, 1]).
        """
        noise = action_noise
        # Warm-up is decided purely by the step count; no gSDE flags are consulted.
        if self.num_timesteps < learning_starts:
            # Warm-up: let each environment choose its action, unperturbed.
            noise = None
            # env_method returns a list (one entry per sub-env); stack it so
            # both branches yield an ndarray as the signature promises.
            unscaled_action = np.array(self.env.env_method("sample_action"))
        else:
            # Continuous actions assume the policy squashes with tanh.
            unscaled_action, _ = self.predict(self._last_obs, deterministic=False)

        if not isinstance(self.action_space, gym.spaces.Box):
            # Discrete case: no normalization or clipping needed.
            return unscaled_action, unscaled_action

        # Rescale from [low, high] to [-1, 1]; the buffer stores this form.
        scaled_action = self.policy.scale_action(unscaled_action)
        if noise is not None:
            scaled_action = np.clip(scaled_action + noise(), -1, 1)

        buffer_action = scaled_action
        action = self.policy.unscale_action(scaled_action)
        return action, buffer_action